{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "369c3444",
   "metadata": {},
   "source": [
    "# ReadtheDocs Retrieval Augmented Generation (RAG) using Milvus Docker Container"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6ffd11a",
   "metadata": {},
   "source": [
    "In this notebook, we are going to use Milvus documentation pages to create a chatbot about our product.  The chatbot is going to follow RAG steps to retrieve chunks of data using Semantic Vector Search, then the Question + Context will be fed as a Prompt to a LLM to generate an answer.\n",
    "\n",
    "Many RAG demos use OpenAI for the Embedding Model and ChatGPT for the Generative AI model.  **In this notebook, we will demo a fully open source RAG stack.**\n",
    "\n",
    "Using open-source Q&A with retrieval saves money since we make free calls to our own data almost all the time - retrieval, evaluation, and development iterations.  We only make a paid call to OpenAI once for the final chat generation step. \n",
    "\n",
    "<div>\n",
    "<img src=\"../../../images/rag_image.png\" width=\"80%\"/>\n",
    "</div>\n",
    "\n",
    "Let's get started!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b2509fe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For colab install these libraries in this order:\n",
    "# !python -m pip install torch transformers sentence-transformers langchain\n",
    "# !python -m pip install -U pymilvus 'pymilvus[model]'\n",
    "# !python -m pip install unstructured openai tqdm numpy ipykernel \n",
    "# !python -m pip install ragas datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d7570b2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import common libraries.\n",
    "import sys, os, time, pprint\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06d8a9d1",
   "metadata": {},
   "source": [
    "## Get and save API keys.\n",
    "\n",
    "The services I'm going to use are:\n",
    "- [Zilliz Cloud](https://cloud.zilliz.com/login)\n",
    "- [OpenAI](https://platform.openai.com/api-keys)\n",
    "- [Anthropic]()\n",
    "- [Anyscale endpoints]()\n",
    "\n",
    "💡 Note: To keep your tokens private, best practice is to use an **env variable**.  See [how to save api key in env variable](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety). <br>\n",
    "\n",
    "👉🏼 In Jupyter, you need a .env file (in same dir as notebooks) containing lines like this:\n",
    "- ZILLIZ_API_KEY=f370c...\n",
    "- OPENAI_API_KEY=sk-H...\n",
    "- VARIABLE_NAME=value..."
   ]
  },
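  {
   "cell_type": "markdown",
   "id": "e3a1f0c7",
   "metadata": {},
   "source": [
    "As a minimal sketch of how the `KEY=value` lines in a .env file can end up in environment variables, the snippet below parses such a file using only the standard library.  The helper name `load_env_file` and the key `DEMO_API_KEY` are hypothetical, made up for illustration; in practice a package such as `python-dotenv` is a common way to do this.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def load_env_file(path):\n",
    "    '''Hypothetical helper: copy KEY=value lines from a .env file into os.environ.'''\n",
    "    loaded = {}\n",
    "    with open(path, encoding='utf-8') as env_file:\n",
    "        for raw_line in env_file:\n",
    "            line = raw_line.strip()\n",
    "            # Skip blank lines, comments, and lines without an '=' sign.\n",
    "            if not line or line.startswith('#') or '=' not in line:\n",
    "                continue\n",
    "            key, value = line.split('=', 1)\n",
    "            loaded[key.strip()] = value.strip()\n",
    "    for key, value in loaded.items():\n",
    "        # Do not overwrite variables that are already set.\n",
    "        os.environ.setdefault(key, value)\n",
    "    return loaded\n",
    "\n",
    "\n",
    "# Demo with a throwaway file and a placeholder value (not a real key).\n",
    "with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "    env_path = os.path.join(tmp_dir, '.env')\n",
    "    with open(env_path, 'w', encoding='utf-8') as env_file:\n",
    "        print('# comment lines are ignored', file=env_file)\n",
    "        print('DEMO_API_KEY=placeholder-value', file=env_file)\n",
    "    demo_keys = load_env_file(env_path)\n",
    "\n",
    "print(sorted(demo_keys))\n",
    "print(os.getenv('DEMO_API_KEY'))\n",
    "```\n",
    "\n",
    "Existing environment variables are left untouched (`setdefault`), so a key already exported in your shell takes precedence over the file."
   ]
  },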
  {
   "cell_type": "markdown",
   "id": "e059b674",
   "metadata": {},
   "source": [
    "## Download Data\n",
    "\n",
    "The data used in this notebook is Milvus documentation web pages.\n",
    "\n",
    "The code block below downloads all the web pages into a local directory called `rtdocs`.  \n",
    "\n",
    "I've already uploaded the `rtdocs` data folder to github, so you should see it if you cloned my repo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "25686cc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UNCOMMENT TO DOWNLOAD THE DOCS.\n",
    "\n",
    "# # !pip install -U langchain\n",
    "# from langchain_community.document_loaders import RecursiveUrlLoader\n",
    "\n",
    "# DOCS_PAGE=\"https://milvus.io/docs/\"\n",
    "\n",
    "# loader = RecursiveUrlLoader(DOCS_PAGE)\n",
    "# docs = loader.load()\n",
    "\n",
    "# num_documents = len(docs)\n",
    "# print(f\"loaded {num_documents} documents\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "83b232dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loaded 22 documents\n"
     ]
    }
   ],
   "source": [
    "# UNCOMMENT TO READ THE DOCS FROM A LOCAL DIRECTORY.\n",
    "\n",
    "# Read docs into LangChain\n",
    "# !pip install -U langchain\n",
    "# !pip install unstructured\n",
    "from langchain.document_loaders import DirectoryLoader\n",
    "\n",
    "# Load HTML files from a local directory\n",
    "path = \"../../RAG/rtdocs/\"\n",
    "loader = DirectoryLoader(path, glob='*.html')\n",
    "docs = loader.load()\n",
    "\n",
    "num_documents = len(docs)\n",
    "print(f\"loaded {num_documents} documents\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb844837",
   "metadata": {},
   "source": [
    "## Start up Milvus running in local Docker (or Zilliz free tier)\n",
    "\n",
⛔️ Make sure you">
    ">⛔️ Make sure the pymilvus version you pip install matches the Milvus server version in the yml file.  **Versions (major and minor) should all match**.\n",
    "\n",
    "1. [Install Docker](https://docs.docker.com/get-docker/)\n",
    "2. Start your Docker Desktop\n",
    "3. If you have an old Milvus container, previous version, delete it from Docker Desktop.\n",
    "4. Download the latest [docker-compose.yml](https://milvus.io/docs/install_standalone-docker.md#Download-the-YAML-file) (or run the wget command, replacing version to what you are using)\n",
    "> wget https://github.com/milvus-io/milvus/releases/download/v2.4.1/milvus-standalone-docker-compose.yml -O docker-compose.yml\n",
    "5. From your terminal:  \n",
    "   - cd into directory where you saved the .yml file (usualy same dir as this notebook)\n",
    "   - docker compose up -d\n",
    "   - verify (either in terminal or on Docker Desktop) the containers are running\n",
    "6. From your code (see notebook code below):\n",
    "   - You already did this (pip install -U pymilvus)\n",
    "   - Import milvus and connect to the local milvus server"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "753e214a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install -U pymilvus\n",
    "# !wget https://github.com/milvus-io/milvus/releases/download/v2.4.1/milvus-standalone-docker-compose.yml -O docker-compose.yml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "86786ab7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pymilvus: 2.4.1\n",
      "v2.4.1\n"
     ]
    }
   ],
   "source": [
    "# STEP 1. CONNECT TO MILVUS STANDALONE DOCKER.\n",
    "\n",
    "import pymilvus, time\n",
    "from pymilvus import (\n",
    "    MilvusClient, utility, connections,\n",
    "    FieldSchema, CollectionSchema, DataType, IndexType,\n",
    "    Collection, AnnSearchRequest, RRFRanker, WeightedRanker\n",
    ")\n",
    "print(f\"Pymilvus: {pymilvus.__version__}\")\n",
    "\n",
    "# Connect to the local server.\n",
    "connection = connections.connect(\n",
    "  alias=\"default\", \n",
    "  host='localhost', # or '0.0.0.0' or 'localhost'\n",
    "  port='19530'\n",
    ")\n",
    "\n",
    "# Get server version.\n",
    "print(utility.get_server_version())\n",
    "\n",
    "# Use no-schema Milvus client uses flexible json key:value format.\n",
    "mc = MilvusClient(connections=connection)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9d758e4",
   "metadata": {},
   "source": [
    "# Optionally, use Zilliz free tier cluster\n",
    "To use fully-managed Milvus on [Ziliz Cloud free trial](https://cloud.zilliz.com/login).  \n",
    "  1. Choose the default \"Starter\" option and accept the default Cloud Provider and Region when you create a cluster. \n",
    "  2. On the Cluster main page, copy your `API Key` and store it locally in a .env variable.  See [this note](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety) how to do that.\n",
    "  3. Also on the Cluster main page, copy the `Public Endpoint URI` and store it somewhere convenient.\n",
    "  4. Jupyter also requires them in a local .env file. <br>\n",
    "Anywhere in the bootcamp directory, create a .env file\n",
    "Insert lines like this, substituting your actual API keys for the sample text: <br>\n",
    "ZILLIZ_API_KEY=f370c <br>\n",
    "OPENAI_API_KEY=sk-H <br>\n",
    "ANYSCALE_ENPOINT_KEY=es <br>\n",
    "ANTHROPIC_API_KEY=sk-an <br>\n",
    "VARIABLE_NAME=value <br>\n",
    "Save the .env file <br>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0806d2db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # STEP 1. CONNECT TO ZILLIZ CLOUD\n",
    "# import os\n",
    "# import pymilvus\n",
    "# print(f\"pymilvus version: {pymilvus.__version__}\")\n",
    "# from pymilvus import connections, utility\n",
    "# TOKEN = os.getenv(\"ZILLIZ_API_KEY\")\n",
    "\n",
    "# # Connect to Zilliz cloud using endpoint URI and API key TOKEN.\n",
    "# # TODO change this.\n",
    "# CLUSTER_ENDPOINT=\"https://in03-xxxx.api.gcp-us-west1.zillizcloud.com:443\"\n",
    "# CLUSTER_ENDPOINT=\"https://in03-48a5b11fae525c9.api.gcp-us-west1.zillizcloud.com:443\"\n",
    "# connections.connect(\n",
    "#   alias='default',\n",
    "#   #  Public endpoint obtained from Zilliz Cloud\n",
    "#   uri=CLUSTER_ENDPOINT,\n",
    "#   # API key or a colon-separated cluster username and password\n",
    "#   token=TOKEN,\n",
    "# )\n",
    "\n",
    "# # Use no-schema Milvus client uses flexible json key:value format.\n",
    "# # https://milvus.io/docs/using_milvusclient.md\n",
    "# mc = MilvusClient(\n",
    "#     uri=CLUSTER_ENDPOINT,\n",
    "#     # API key or a colon-separated cluster username and password\n",
    "#     token=TOKEN)\n",
    "\n",
    "# # Check if the server is ready and get colleciton name.\n",
    "# print(f\"Type of server: {utility.get_server_version()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f39af3fd",
   "metadata": {},
   "source": [
    "## Load the Embedding Model checkpoint and use it to create vector embeddings\n",
    "\n",
    "#### What are Embeddings?\n",
    "\n",
    "Check out [this blog](https://zilliz.com/glossary/vector-embeddings) for an introduction to embeddings.  \n",
    "\n",
    "An excellent place to start is by selecting an embedding model from the [HuggingFace MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard), sorted descending by the \"Retrieval Average'' column since this task is most relevant to RAG. Then, choose the smallest, highest-ranking embedding model. But, Beware!! some models listed are overfit to the training data, so they won't perform on your data as promised.  \n",
    "\n",
    "Milvus (and Zilliz) only supports tested embedding models that are not overfit."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b01d6622",
   "metadata": {},
   "source": [
    "#### In this notebook, we will use the **open-source BGE-M3** which supports: \n",
    "- over 100 languages\n",
    "- context lengths of up to 8192\n",
    "- multiple embedding inferences such as dense (semantic), sparse (lexical), and multi-vector Colbert reranking. \n",
    "\n",
    "BGE-M3 holds the distinction of being the first embedding model to offer support for all three retrieval methods, achieving state-of-the-art performance on multi-lingual (MIRACL) and cross-lingual (MKQA) benchmark tests.  [Paper](https://arxiv.org/abs/2402.03216), [HuggingFace](https://huggingface.co/BAAI/bge-m3)\n",
    "\n",
    "**[Milvus](https://github.com/milvus-io/milvus)**, the world's first Open Source Vector Database, plays a vital role in semantic search with scaleable, efficient storage and search for GenerativeAI workflows. Its advanced functionalities include metadata filtering and hybrid search.  Since version 2.4, Milvus has built-in support for BGE M3.\n",
    "\n",
    "<div>\n",
    "<img src=\"../../../images/bge_m3.png\" width=\"80%\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1805f966",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device: cpu\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d12db900e6724666a8a7de437b637ed8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dense_dim: 1024\n",
      "sparse_dim: 250002\n",
      "colbert_dim: 1024\n"
     ]
    }
   ],
   "source": [
    "# STEP 2. USE A MILVUS BUILT-IN OPEN SOURCE EMBEDDING MODEL.\n",
    "\n",
    "from pymilvus.model.hybrid import BGEM3EmbeddingFunction\n",
    "import torch\n",
    "\n",
    "# Initialize torch settings\n",
    "DEVICE = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"device: {DEVICE}\")\n",
    "\n",
    "# Initialize a Milvus built-in sparse-dense-reranking encoder.\n",
    "# https://huggingface.co/BAAI/bge-m3\n",
    "embedding_model = BGEM3EmbeddingFunction(use_fp16=False, device=DEVICE)\n",
    "EMBEDDING_DIM = embedding_model.dim['dense']\n",
    "print(f\"dense_dim: {EMBEDDING_DIM}\")\n",
    "print(f\"sparse_dim: {embedding_model.dim['sparse']}\")\n",
    "print(f\"colbert_dim: {embedding_model.dim['colbert_vecs']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9609497f",
   "metadata": {},
   "source": [
    "## Create a Milvus collection\n",
    "\n",
    "You can think of a collection in Milvus like a \"table\" in SQL databases.  The **collection** will contain the \n",
    "- **Schema** (or [no-schema Milvus client](https://milvus.io/docs/using_milvusclient.md)).  \n",
    "💡 You'll need the vector `EMBEDDING_DIM` parameter from your embedding model.\n",
    "Typical values are:\n",
    "   - 1024 for sbert embedding models\n",
    "   - 1536 for ada-002 OpenAI embedding models\n",
    "- **Vector index** for efficient vector search\n",
    "- **Vector distance metric** for measuring nearest neighbor vectors\n",
    "- **Consistency level**\n",
    "In Milvus, transactional consistency is possible; however, according to the [CAP theorem](https://en.wikipedia.org/wiki/CAP_theorem), some latency must be sacrificed. 💡 Searching movie reviews is not mission-critical, so [`eventually`](https://milvus.io/docs/consistency.md) consistent is fine here.\n",
    "\n",
    "## Add a Vector Index\n",
    "\n",
    "The vector index determines the vector **search algorithm** used to find the closest vectors in your data to the query a user submits.  \n",
    "\n",
    "Most vector indexes use different sets of parameters depending on whether the database is:\n",
    "- **inserting vectors** (creation mode) - vs - \n",
    "- **searching vectors** (search mode) \n",
    "\n",
    "Scroll down the [docs page](https://milvus.io/docs/index.md) to see a table listing different vector indexes available on Milvus.  For example:\n",
    "- FLAT - deterministic exhaustive search\n",
    "- IVF_FLAT or IVF_SQ8 - Hash index (stochastic approximate search)\n",
    "- HNSW - Graph index (stochastic approximate search)\n",
    "- AUTOINDEX - OSS or [Zilliz cloud](https://docs.zilliz.com/docs/autoindex-explained) automatic index based on type of GPU, size of data.\n",
    "\n",
    "Besides a search algorithm, we also need to specify a **distance metric**, that is, a definition of what is considered \"close\" in vector space.  In the cell below, the [`HNSW`](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md) search index is chosen.  Its possible distance metrics are one of:\n",
    "- L2 - L2-norm\n",
    "- IP - Dot-product\n",
    "- COSINE - Angular distance\n",
    "\n",
    "💡 Most use cases work better with normalized embeddings, in which case L2 is useless (every vector has length=1) and IP and COSINE are the same.  Only choose L2 if you plan to keep your embeddings unnormalized."
   ]
  },
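  {
   "cell_type": "markdown",
   "id": "b84d2a6e",
   "metadata": {},
   "source": [
    "A quick numerical check of how the three metrics relate on normalized vectors.  This is only a sketch with made-up random vectors (not BGE-M3 output); it uses numpy alone and the variable names are illustrative.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(seed=0)\n",
    "vectors = rng.normal(size=(5, 8))\n",
    "query = rng.normal(size=8)\n",
    "\n",
    "# Normalize everything to unit length.\n",
    "vectors = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)\n",
    "query = query / np.linalg.norm(query)\n",
    "\n",
    "ip = vectors @ query\n",
    "cosine = ip / (np.linalg.norm(vectors, axis=1) * np.linalg.norm(query))\n",
    "l2 = np.linalg.norm(vectors - query, axis=1)\n",
    "\n",
    "# On unit vectors, the inner product equals the cosine similarity.\n",
    "same_scores = bool(np.allclose(ip, cosine))\n",
    "# Squared L2 distance equals 2 - 2 * cosine, so both rank neighbors identically.\n",
    "l2_identity = bool(np.allclose(l2 ** 2, 2 - 2 * cosine))\n",
    "same_ranking = bool(np.array_equal(np.argsort(-cosine), np.argsort(l2)))\n",
    "print(same_scores, l2_identity, same_ranking)\n",
    "```\n",
    "\n",
    "All three checks should come out `True`, which is why the choice between IP, COSINE, and L2 only matters for unnormalized embeddings."
   ]
  },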
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "559dc3ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully dropped collection: `MilvusDocs`\n",
      "Successfully created collection: `MilvusDocs`\n"
     ]
    }
   ],
   "source": [
    "# STEP 3. CREATE A NO-SCHEMA MILVUS COLLECTION AND DEFINE THE DATABASE INDEX.\n",
    "# See docstrings for more information.\n",
    "# https://github.com/milvus-io/pymilvus/blob/master/examples/hello_hybrid_sparse_dense.py\n",
    "\n",
    "from pymilvus import MilvusClient\n",
    "\n",
    "# Set the Milvus collection name.\n",
    "COLLECTION_NAME = \"MilvusDocs\"\n",
    "\n",
    "# Specify the data schema for the new Collection.\n",
    "MAX_LENGTH = 65535\n",
    "fields = [\n",
    "    # Use auto generated id as primary key\n",
    "    FieldSchema(name=\"id\", dtype=DataType.INT64,\n",
    "                is_primary=True, auto_id=True, max_length=100),\n",
    "    FieldSchema(name=\"sparse_vector\", dtype=DataType.SPARSE_FLOAT_VECTOR),\n",
    "    FieldSchema(name=\"dense_vector\", dtype=DataType.FLOAT_VECTOR,\n",
    "                dim=EMBEDDING_DIM),\n",
    "    FieldSchema(name=\"chunk\", dtype=DataType.VARCHAR, max_length=MAX_LENGTH),\n",
    "    FieldSchema(name=\"source\", dtype=DataType.VARCHAR, max_length=MAX_LENGTH),\n",
    "    FieldSchema(name=\"h1\", dtype=DataType.VARCHAR, max_length=100),\n",
    "    FieldSchema(name=\"h2\", dtype=DataType.VARCHAR, max_length=MAX_LENGTH),\n",
    "]\n",
    "schema = CollectionSchema(fields, \"\")\n",
    "\n",
    "# Check if collection already exists, if so drop it.\n",
    "has = utility.has_collection(COLLECTION_NAME)\n",
    "if has:\n",
    "    drop_result = utility.drop_collection(COLLECTION_NAME)\n",
    "    print(f\"Successfully dropped collection: `{COLLECTION_NAME}`\")\n",
    "\n",
    "# Create the collection.\n",
    "schema = CollectionSchema(fields, \"\")\n",
    "col = Collection(COLLECTION_NAME, schema, consistency_level=\"Eventually\")\n",
    "\n",
    "# Add custom HNSW search index to the collection.\n",
    "# M = max number graph connections per layer. Large M = denser graph.\n",
    "# Choice of M: 4~64, larger M for larger data and larger embedding lengths.\n",
    "M = 16\n",
    "# efConstruction = num_candidate_nearest_neighbors per layer. \n",
    "# Use Rule of thumb: int. 8~512, efConstruction = M * 2.\n",
    "efConstruction = M * 2\n",
    "# Create the search index for local Milvus server.\n",
    "INDEX_PARAMS = dict({\n",
    "    'M': M,               \n",
    "    \"efConstruction\": efConstruction })\n",
    "\n",
    "# Create indices for the vector fields. \n",
    "# The indices will pre-load data into memory for efficient search.\n",
    "sparse_index = {\"index_type\": \"SPARSE_INVERTED_INDEX\", \"metric_type\": \"IP\"}\n",
    "dense_index = {\"index_type\": \"HNSW\", \"metric_type\": \"COSINE\", \"params\": INDEX_PARAMS}\n",
    "col.create_index(\"sparse_vector\", sparse_index)\n",
    "col.create_index(\"dense_vector\", dense_index)\n",
    "col.load()\n",
    "\n",
    "print(f\"Successfully created collection: `{COLLECTION_NAME}`\")\n",
    "# print(mc.describe_collection(COLLECTION_NAME))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c60423a5",
   "metadata": {},
   "source": [
    "## HTML Chunking\n",
    "\n",
    "Before embedding, it is necessary to decide your chunk strategy, chunk size, and chunk overlap.  This section uses:\n",
    "- **Strategy** = Use markdown header hierarchies.  Keep markdown sections together unless they are too long.\n",
    "- **Chunk size** = Use the embedding model's parameter `MAX_SEQ_LENGTH`\n",
    "- **Overlap** = Rule-of-thumb 10-15%\n",
    "- **Function** = \n",
    "  - Langchain's `HTMLHeaderTextSplitter` to split markdown sections.\n",
    "  - Langchain's `RecursiveCharacterTextSplitter` to split up long reviews recursively.\n",
    "\n",
    "\n",
    "Notice below, each chunk is grounded with the document source page.  <br>\n",
    "In addition, header titles are kept together with the chunk of markdown text."
   ]
  },
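  {
   "cell_type": "markdown",
   "id": "c95f17d3",
   "metadata": {},
   "source": [
    "To make chunk size and overlap concrete, here is a simplified sliding-window sketch in plain Python.  The helper `chunk_with_overlap` is hypothetical and is not what `RecursiveCharacterTextSplitter` does internally (that splitter also tries to break on separators such as paragraphs and sentences); it only illustrates how a 10% overlap makes neighboring chunks share text.\n",
    "\n",
    "```python\n",
    "def chunk_with_overlap(text, chunk_size=512, overlap_ratio=0.10):\n",
    "    '''Hypothetical helper: fixed-size character windows that overlap.'''\n",
    "    overlap = int(chunk_size * overlap_ratio)\n",
    "    step = chunk_size - overlap\n",
    "    pieces = []\n",
    "    for start in range(0, len(text), step):\n",
    "        pieces.append(text[start:start + chunk_size])\n",
    "        # Stop once the current window reaches the end of the text.\n",
    "        if start + chunk_size >= len(text):\n",
    "            break\n",
    "    return pieces\n",
    "\n",
    "\n",
    "# Synthetic text of 1200 characters made of repeating digits.\n",
    "sample_text = ''.join(str(i % 10) for i in range(1200))\n",
    "sample_chunks = chunk_with_overlap(sample_text)\n",
    "print([len(piece) for piece in sample_chunks])\n",
    "```\n",
    "\n",
    "With the defaults, consecutive chunks share 51 characters, so a sentence cut at a chunk boundary still appears whole in at least one neighbor when it is shorter than the overlap."
   ]
  },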
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5ab9cd1e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "chunk_size: 512, chunk_overlap: 51.0\n",
      "chunking time: 0.02252483367919922\n",
      "docs: 22, split into: 22\n",
      "split into chunks: 304, type: list of <class 'langchain_core.documents.base.Document'>\n",
      "\n",
      "Looking at a sample chunk...\n",
      "Why Milvus Docs Tutorials Tools Blog Community Stars0 Try Managed Milvus FREE Search Home v2.4.x Abo\n",
      "{'h1': 'Why Milvus Docs Tutorials Tools Blog Community Stars0 Try Managed Milvus FREE Search Home v2.4.x Abo', 'source': '../../RAG/rtdocs/quickstart.html'}\n"
     ]
    }
   ],
   "source": [
    "# # STEP 4. PREPARE DATA: CHUNK AND EMBED\n",
    "# !python -m pip install lxml\n",
    "from langchain_community.document_transformers import BeautifulSoupTransformer\n",
    "from langchain.text_splitter import HTMLHeaderTextSplitter, RecursiveCharacterTextSplitter\n",
    "import numpy as np\n",
    "import pprint\n",
    "\n",
    "# Define chunk size 512 and overlap 10% chunk_size.\n",
    "chunk_size = 512\n",
    "chunk_overlap = np.round(chunk_size * 0.10, 0)\n",
    "print(f\"chunk_size: {chunk_size}, chunk_overlap: {chunk_overlap}\")\n",
    "\n",
    "# Create an instance of the RecursiveCharacterTextSplitter\n",
    "child_splitter = RecursiveCharacterTextSplitter(\n",
    "    chunk_size = chunk_size,\n",
    "    chunk_overlap = chunk_overlap,\n",
    "    length_function = len,  # using built-in Python len function\n",
    ")\n",
    "\n",
    "# Define the headers to split on for the HTMLHeaderTextSplitter\n",
    "headers_to_split_on = [\n",
    "    (\"h1\", \"Header 1\"),\n",
    "    (\"h2\", \"Header 2\"),\n",
    "]\n",
    "# Create an instance of the HTMLHeaderTextSplitter\n",
    "html_splitter = HTMLHeaderTextSplitter(headers_to_split_on=headers_to_split_on)\n",
    "\n",
    "# Split the HTML text using the HTMLHeaderTextSplitter.\n",
    "start_time = time.time()\n",
    "html_header_splits = []\n",
    "for doc in docs:\n",
    "    splits = html_splitter.split_text(doc.page_content)\n",
    "    for split in splits:\n",
    "        # Add the source URL and header values to the metadata\n",
    "        metadata = {}\n",
    "        new_text = split.page_content\n",
    "        for header_name, metadata_header_name in headers_to_split_on:\n",
    "            # Handle exception if h1 does not exist.\n",
    "            try:\n",
    "                header_value = new_text.split(\"¶ \")[0].strip()[:100]\n",
    "                metadata[header_name] = header_value\n",
    "            except:\n",
    "                break\n",
    "            # Handle exception if h2 does not exist.\n",
    "            try:\n",
    "                new_text = new_text.split(\"¶ \")[1].strip()[:50]\n",
    "            except:\n",
    "                break\n",
    "        split.metadata = {\n",
    "            **metadata,\n",
    "            \"source\": doc.metadata[\"source\"]\n",
    "        }\n",
    "        # Add the header to the text\n",
    "        split.page_content = split.page_content\n",
    "    html_header_splits.extend(splits)\n",
    "\n",
    "    # # TODO - Uncomment to save each doc.page_content as a local html file under OUTPUT_DIR.\n",
    "    # OUTPUT_DIR = \"output\"\n",
    "    # # Set filename to first 50 characters of h1 header.\n",
    "    # filename = doc.metadata[\"source\"].split(\"/\")[-1].split(\".\")[0][:50]\n",
    "    # with open(f\"{OUTPUT_DIR}/{filename}.html\", \"w\") as f:\n",
    "    #     f.write(doc.page_content)\n",
    "\n",
    "# Split the documents further into smaller, recursive chunks.\n",
    "chunks = child_splitter.split_documents(html_header_splits)\n",
    "\n",
    "end_time = time.time()\n",
    "print(f\"chunking time: {end_time - start_time}\")\n",
    "print(f\"docs: {len(docs)}, split into: {len(html_header_splits)}\")\n",
    "print(f\"split into chunks: {len(chunks)}, type: list of {type(chunks[0])}\") \n",
    "\n",
    "# Inspect a chunk.\n",
    "print()\n",
    "print(\"Looking at a sample chunk...\")\n",
    "print(chunks[0].page_content[:100])\n",
    "print(chunks[0].metadata)\n",
    "\n",
    "# # TODO - Uncomment to print child splits with their associated header metadata for debugging.\n",
    "# print()\n",
    "# for child in chunks:\n",
    "#     print(f\"Content: {child.page_content}\")\n",
    "#     print(f\"Metadata: {child.metadata}\")\n",
    "#     print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "512130a3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Why Milvus Docs Tutorials Tools Blog Community Stars0 Try Managed Milvus FREE Search Home v2.4.x Abo\n",
      "{'h1': 'Why Milvus Docs Tutorials Tools Blog Community Stars0 Try Managed Milvus FREE Search Home v2.4.x Abo', 'source': 'https://milvus.io/docs/quickstart.md'}\n"
     ]
    }
   ],
   "source": [
    "# Clean up the metadata urls\n",
    "for doc in chunks:\n",
    "    new_url = doc.metadata[\"source\"]\n",
    "    new_url = new_url.replace(\"../../RAG/rtdocs\", \"https://milvus.io/docs\")\n",
    "    new_url = new_url.replace(\".html\", \".md\")\n",
    "    doc.metadata.update({\"source\": new_url})\n",
    "\n",
    "print(chunks[0].page_content[:100])\n",
    "print(chunks[0].metadata)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "229daa80",
   "metadata": {},
   "source": [
    "Use the built-in Milvus BGE M3 embedding functions.  The output will be 2 vectors:\n",
    "- `embeddings['dense'][i]` is a list of numpy arrays, one per chunk. Milvus supports more than 1 dense embedding vector if desired, so i is the ith dense embedding vector.\n",
    "- `embeddings['sparse'][:, [i]]` is a scipy sparse matrix where each column represents a chunk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d223c6f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Inference Embeddings: 100%|██████████| 19/19 [00:35<00:00,  1.85s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embedding time for 304 chunks: 35.11 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# STEP 5. TRANSFORM CHUNKS INTO VECTORS USING EMBEDDING MODEL INFERENCE.\n",
    "\n",
    "# BGEM3EmbeddingFunction input is docs as a list of strings.\n",
    "list_of_strings = [doc.page_content for doc in chunks if hasattr(doc, 'page_content')]\n",
    "\n",
    "# Embedding inference using the Milvus built-in sparse-dense-reranking encoder.\n",
    "start_time = time.time()\n",
    "embeddings = embedding_model(list_of_strings)\n",
    "end_time = time.time()\n",
    "\n",
    "print(f\"Embedding time for {len(list_of_strings)} chunks: \", end=\"\")\n",
    "print(f\"{np.round(end_time - start_time, 2)} seconds\")\n",
    "\n",
    "# Inference Embeddings: 100%|██████████| 19/19 [00:35<00:00,  1.86s/it]\n",
    "# Embedding time for 304 chunks: 35.74 seconds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9bd8153",
   "metadata": {},
   "source": [
    "## Insert data into Milvus\n",
    "\n",
    "For each original text chunk, we'll write the sextuplet (`chunk, h1, h2, source, dense_vector, sparse_vector`) into the database.\n",
    "\n",
    "<div>\n",
    "<img src=\"../../../images/db_insert_sparse_dense.png\" width=\"80%\"/>\n",
    "</div>\n",
    "\n",
    "**The Milvus Client wrapper can only handle loading data from a list of dictionaries.**\n",
    "\n",
    "Otherwise, in general, Milvus supports loading data from:\n",
    "- pandas dataframes \n",
    "- list of dictionaries\n",
    "\n",
    "Below, we use the embedding model provided by HuggingFace, download its checkpoint, and run it locally as the encoder.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "79dd2299",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "304\n",
      "<class 'dict'> 6\n",
      "{'chunk': 'Why Milvus Docs Tutorials Tools Blog Community Stars0 Try Managed '\n",
      "          'Milvus FREE Search Home v2.4.x About Milvus Get '\n",
      "          'StartedPrerequisitesInstall MilvusInstall SDKsQuickstart Concepts '\n",
      "          'User Guide Embeddings Administration Guide Tools Integrations '\n",
      "          'Example Applications FAQs API reference Quickstart This guide '\n",
      "          'explains how to connect to your Milvus cluster and performs CRUD '\n",
      "          'operations in minutes Before you start You have installed Milvus '\n",
      "          'standalone or Milvus cluster. You have installed preferred SDKs. '\n",
      "          'You can',\n",
      " 'dense_vector': array([-0.01666467,  0.05284622, -0.05246124, ..., -0.0182556 ,\n",
      "        0.03670057, -0.00945159], dtype=float32),\n",
      " 'h1': 'Why Milvus Docs Tutorials Tools Blog Community Sta',\n",
      " 'h2': '',\n",
      " 'source': 'https://milvus.io/docs/quickstart.md',\n",
      " 'sparse_vector': <1x250002 sparse array of type '<class 'numpy.float32'>'\n",
      "\twith 63 stored elements in Compressed Sparse Row format>}\n"
     ]
    }
   ],
   "source": [
    "# STEP 6. INSERT CHUNK LIST INTO MILVUS OR ZILLIZ.\n",
    "\n",
    "# Create chunk_list and dict_list in a single loop\n",
    "dict_list = []\n",
    "for chunk, sparse, dense in zip(chunks, embeddings[\"sparse\"], embeddings[\"dense\"]):\n",
    "    # Assemble embedding vector, original text chunk, metadata.\n",
    "    chunk_dict = {\n",
    "        'chunk': chunk.page_content,\n",
    "        'h1': chunk.metadata.get('h1', \"\")[:50],\n",
    "        'h2': chunk.metadata.get('h2', \"\")[:50],\n",
    "        'source': chunk.metadata.get('source', \"\"),\n",
    "        'sparse_vector': sparse,\n",
    "        'dense_vector': dense\n",
    "    }\n",
    "    dict_list.append(chunk_dict)\n",
    "\n",
    "# TODO - Uncomment to inspect the first chunk and its metadata.\n",
    "print(len(dict_list))\n",
    "print(type(dict_list[0]), len(dict_list[0]))\n",
    "pprint.pprint(dict_list[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f3ac0d5c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start inserting entities\n",
      "Milvus insert time for 304 vectors: 0.19 seconds\n"
     ]
    }
   ],
   "source": [
    "# Insert data into the Milvus collection.\n",
    "print(\"Start inserting entities\")\n",
    "start_time = time.time()\n",
    "col.insert(dict_list)\n",
    "\n",
    "end_time = time.time()\n",
    "print(f\"Milvus insert time for {len(dict_list)} vectors: \", end=\"\")\n",
    "print(f\"{np.round(end_time - start_time, 2)} seconds\")\n",
    "col.flush()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd834ae6",
   "metadata": {},
   "source": [
    "## Aside - example Milvus collection API calls\n",
    "https://milvus.io/docs/manage-collections.md#View-Collections\n",
    "\n",
    "Below are some common API calls for checking a collection.\n",
    "- `.num_entities`, flushes data and executes row count.\n",
    "- `.describe_collection()`, gives details about the schema, index, collection.\n",
    "- `.query()`, gives back selected data from the collection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "aa628f3f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Count rows: 304\n",
      "timing: 0.002 seconds\n",
      "\n",
      "{'aliases': [],\n",
      " 'auto_id': True,\n",
      " 'collection_id': 449601743893922125,\n",
      " 'collection_name': 'MilvusDocs',\n",
      " 'consistency_level': 3,\n",
      " 'description': '',\n",
      " 'enable_dynamic_field': False,\n",
      " 'fields': [{'auto_id': True,\n",
      "             'description': '',\n",
      "             'field_id': 100,\n",
      "             'is_primary': True,\n",
      "             'name': 'id',\n",
      "             'params': {},\n",
      "             'type': <DataType.INT64: 5>},\n",
      "            {'description': '',\n",
      "             'field_id': 101,\n",
      "             'name': 'sparse_vector',\n",
      "             'params': {},\n",
      "             'type': <DataType.SPARSE_FLOAT_VECTOR: 104>},\n",
      "            {'description': '',\n",
      "             'field_id': 102,\n",
      "             'name': 'dense_vector',\n",
      "             'params': {'dim': 1024},\n",
      "             'type': <DataType.FLOAT_VECTOR: 101>},\n",
      "            {'description': '',\n",
      "             'field_id': 103,\n",
      "             'name': 'chunk',\n",
      "             'params': {'max_length': 65535},\n",
      "             'type': <DataType.VARCHAR: 21>},\n",
      "            {'description': '',\n",
      "             'field_id': 104,\n",
      "             'name': 'source',\n",
      "             'params': {'max_length': 65535},\n",
      "             'type': <DataType.VARCHAR: 21>},\n",
      "            {'description': '',\n",
      "             'field_id': 105,\n",
      "             'name': 'h1',\n",
      "             'params': {'max_length': 100},\n",
      "             'type': <DataType.VARCHAR: 21>},\n",
      "            {'description': '',\n",
      "             'field_id': 106,\n",
      "             'name': 'h2',\n",
      "             'params': {'max_length': 65535},\n",
      "             'type': <DataType.VARCHAR: 21>}],\n",
      " 'num_partitions': 1,\n",
      " 'num_shards': 1,\n",
      " 'properties': {}}\n",
      "timing: 0.0016 seconds\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Example Milvus Collection utility API calls.\n",
    "# https://milvus.io/docs/manage-collections.md#View-Collections\n",
    "\n",
    "# Count rows.\n",
    "start_time = time.time()\n",
    "print(f\"Count rows: {col.num_entities}\")\n",
    "end_time = time.time()\n",
    "print(f\"timing: {np.round(end_time - start_time, 4)} seconds\")\n",
    "print()\n",
    "\n",
    "# View collection info, incurs a call to .flush() first.\n",
    "start_time = time.time()\n",
    "pprint.pprint(mc.describe_collection(COLLECTION_NAME))\n",
    "end_time = time.time()\n",
    "print(f\"timing: {np.round(end_time - start_time, 4)} seconds\")\n",
    "print()\n",
    "\n",
    "# View rows. Careful, this can be a lot of output!\n",
    "# OUTPUT_FIELDS = [\"id\", \"h1\", \"h2\", \"source\", \"chunk\"]\n",
    "# res = mc.query( collection_name=COLLECTION_NAME, \n",
    "#                filter=\"id <= 449197422118227014\", \n",
    "#                output_fields = OUTPUT_FIELDS, )\n",
    "# pprint.pprint(res)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02c589ff",
   "metadata": {},
   "source": [
    "## Ask a question about your data\n",
    "\n",
    "So far in this demo notebook: \n",
    "1. Your custom data has been mapped into a vector embedding space\n",
    "2. Those vector embeddings have been saved into a vector database\n",
    "\n",
    "Next, you can ask a question about your custom data!\n",
    "\n",
    "💡 In LLM vocabulary:\n",
    "> **Query** is the generic term for user questions.  \n",
    "A query is a list of multiple individual questions, up to maybe 1000 different questions!\n",
    "\n",
    "> **Question** usually refers to a single user question.  \n",
    "In our example below, the user question is \"What is AUTOINDEX in Milvus Client?\"\n",
    "\n",
    "> **Semantic Search** = very fast search of the entire knowledge base to find the `TOP_K` documentation chunks with the closest embeddings to the user's query.\n",
    "\n",
    "💡 The same model should always be used for consistency for all the embeddings data and the query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "5e7f41f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "query length: 75\n"
     ]
    }
   ],
   "source": [
    "# Define a sample question about your data.\n",
    "QUESTION1 = \"What do the parameters for HNSW mean?\"\n",
    "QUESTION2 = \"What are good default values for HNSW parameters with 25K vectors dim 1024?\"\n",
    "QUESTION3 = \"What is the default AUTOINDEX distance metric in Milvus Client?\"\n",
    "QUESTION4 = \"What does nlist mean in ivf_flat?\"\n",
    "\n",
    "# In case you want to ask all the questions at once.\n",
    "QUERY = [QUESTION1, QUESTION2, QUESTION3, QUESTION4]\n",
    "\n",
    "# Inspect the length of one question.\n",
    "QUERY_LENGTH = len(QUESTION2)\n",
    "print(f\"query length: {QUERY_LENGTH}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "cd25ffca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SELECT A PARTICULAR QUESTION TO ASK.\n",
    "\n",
    "SAMPLE_QUESTION = QUESTION1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ea29411",
   "metadata": {},
   "source": [
    "## Execute a vector search\n",
    "\n",
    "Search Milvus using [PyMilvus API](https://milvus.io/docs/search.md).\n",
    "\n",
    "💡 By their nature, vector searches are \"semantic\" searches.  For example, if you were to search for \"leaky faucet\": \n",
    "> **Traditional Key-word Search** - either or both words \"leaky\", \"faucet\" would have to match some text in order to return a web page or link text to the document.\n",
    "\n",
    "> **Semantic search** - results containing words \"drippy\" \"taps\" would be returned as well because these words mean the same thing even though they are different words."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "2bcf6cdc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Milvus Client search time for 304 vectors: 0.00729680061340332 seconds\n",
      "type: <class 'pymilvus.client.abstract.Hits'>, count: 2\n"
     ]
    }
   ],
   "source": [
    "# STEP 7. RETRIEVE ANSWERS FROM YOUR DOCUMENTS STORED IN MILVUS OR ZILLIZ.\n",
    "\n",
    "# Load the index into memory for search.\n",
    "col.load()\n",
    "\n",
    "# Embed the question using the same encoder.\n",
    "query_embeddings = embedding_model([SAMPLE_QUESTION])\n",
    "TOP_K = 2\n",
    "\n",
    "# Return top k results with HNSW index.\n",
    "SEARCH_PARAMS = dict({\n",
    "    # Re-use index param for num_candidate_nearest_neighbors.\n",
    "    \"ef\": INDEX_PARAMS['efConstruction']\n",
    "    })\n",
    "\n",
    "# Prepare the search requests for both vector fields\n",
    "sparse_search_params = {\"metric_type\": \"IP\"}\n",
    "sparse_req = AnnSearchRequest(\n",
    "                query_embeddings[\"sparse\"],\n",
    "                \"sparse_vector\", sparse_search_params, limit=TOP_K)\n",
    "\n",
    "dense_search_params = {\"metric_type\": \"COSINE\"}\n",
    "dense_search_params.update(SEARCH_PARAMS)\n",
    "dense_req = AnnSearchRequest(\n",
    "                query_embeddings[\"dense\"],\n",
    "                \"dense_vector\", dense_search_params, limit=TOP_K)\n",
    "\n",
    "# Define output fields to return.\n",
    "OUTPUT_FIELDS = [\"id\", \"h1\", \"h2\", \"source\", \"chunk\"]\n",
    "\n",
    "# Run semantic vector search using your query and the vector database.\n",
    "start_time = time.time()\n",
    "# Use the reranker.\n",
    "results = col.hybrid_search([\n",
    "            sparse_req, dense_req], rerank=RRFRanker(),\n",
    "            limit=TOP_K, output_fields=OUTPUT_FIELDS)\n",
    "# # No reranking use dense only.\n",
    "# results = col.hybrid_search([\n",
    "#             sparse_req, dense_req], rerank=WeightedRanker(0., 1.0),\n",
    "#             limit=TOP_K, output_fields=OUTPUT_FIELDS)\n",
    "\n",
    "elapsed_time = time.time() - start_time\n",
    "print(f\"Milvus Client search time for {len(dict_list)} vectors: {elapsed_time} seconds\")\n",
    "\n",
    "# Inspect search result.\n",
    "print(f\"type: {type(results[0])}, count: {len(results[0])}\")\n",
    "\n",
    "# Currently Milvus only support 1 query in the same hybrid search request, so\n",
    "# we inspect res[0] directly. In future release Milvus will accept batch\n",
    "# hybrid search queries in the same call.\n",
    "results = results[0]\n",
    "\n",
    "# Milvus Client search time for 304 vectors: 0.02100086212158203 seconds\n",
    "# type: <class 'pymilvus.client.abstract.Hits'>, count: 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa3cade1",
   "metadata": {},
   "source": [
    "## Assemble and inspect the search result\n",
    "\n",
    "The search result is in the variable `results[0]` consisting of top_k-count of objects of type `'pymilvus.client.abstract.Hits'`\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "06076f7d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Retrieved result #449601743892518433\n",
      "distance = 0.032522473484277725\n",
      "\n",
      "Retrieved result #449601743892518434\n",
      "distance = 0.032522473484277725\n",
      "\n",
      "[(449601743892518433,\n",
      "  0.032522473484277725,\n",
      "  'this value can improve recall rate at the cost of increased search time. '\n",
      "  '[1, 65535] 2 HNSW HNSW (Hierarchical Navigable Small World Graph) is a '\n",
      "  'graph-based indexing algorithm. It builds a multi-layer navigation '\n",
      "  'structure for an image according to certain rules. In this structure, the '\n",
      "  'upper layers are more sparse and the distances between nodes are farther; '\n",
      "  'the lower layers are denser and the distances between nodes are closer. The '\n",
      "  'search starts from the uppermost layer, finds the node closest to the '\n",
      "  'target',\n",
      "  'https://milvus.io/docs/index.md',\n",
      "  {'id': 449601743892518433}),\n",
      " (449601743892518434,\n",
      "  0.032522473484277725,\n",
      "  'layer, finds the node closest to the target in this layer, and then enters '\n",
      "  'the next layer to begin another search. After multiple iterations, it can '\n",
      "  'quickly approach the target position. In order to improve performance, HNSW '\n",
      "  'limits the maximum degree of nodes on each layer of the graph to M. In '\n",
      "  'addition, you can use efConstruction (when building index) or ef (when '\n",
      "  'searching targets) to specify a search range. Index building parameters '\n",
      "  'Parameter Description Range M Maximum degree of the node (2, 2048)',\n",
      "  'https://milvus.io/docs/index.md',\n",
      "  {'h1': 'Why Milvus Docs Tutorials Tools Blog Community Sta'})]\n"
     ]
    }
   ],
   "source": [
    "# Assemble retrieved context and context metadata.\n",
    "METADATA_FIELDS = [f for f in OUTPUT_FIELDS if f != 'chunk']\n",
    "\n",
    "# Assemble retrieved ids, distances, contexts, sources, and metadata.\n",
    "ids = []\n",
    "distances = []\n",
    "contexts = []\n",
    "sources = []\n",
    "metas = []\n",
    "for i in range(len(results)):\n",
    "    print(f\"Retrieved result #{results[i].id}\")\n",
    "    ids.append(results[i].id)\n",
    "    print(f\"distance = {results[i].distance}\")\n",
    "    distances.append(results[i].distance)\n",
    "    # print(f\"Context: {results[i].entity.chunk[:150]}\")\n",
    "    contexts.append(results[i].entity.chunk)\n",
    "    for j in METADATA_FIELDS:\n",
    "        if hasattr(results[i].entity, j):\n",
    "            meta_dict = {j: getattr(results[i].entity, j)}\n",
    "            metas.append(meta_dict)\n",
    "            if j == \"source\":\n",
    "                # print(f\"{j}: {getattr(results[i].entity, j)}\")\n",
    "                sources.append(getattr(results[i].entity, j))\n",
    "    print()\n",
    "\n",
    "# Keep results in a list of tuples.\n",
    "formatted_results = list(zip(ids, distances, contexts, sources, metas)) \n",
    "\n",
    "# TODO - Uncomment to print the results.\n",
    "pprint.pprint(formatted_results)\n",
    "\n",
    "\n",
    "# Reranking: I only see differences in positions 7-10, when k=10.\n",
    "# Rerank = True\n",
    "# Retrieved result #449197422118225569\n",
    "# distance = 0.032522473484277725\n",
    "\n",
    "# Retrieved result #449197422118225570\n",
    "# distance = 0.032522473484277725\n",
    "\n",
    "# Retrieved result #449197422118225710\n",
    "# distance = 0.03079839050769806\n",
    "\n",
    "# Retrieved result #449197422118225553\n",
    "# distance = 0.03077651560306549\n",
    "\n",
    "# Retrieved result #449197422118225771\n",
    "# distance = 0.03077651560306549\n",
    "\n",
    "# Retrieved result #449197422118225565\n",
    "# distance = 0.03030998818576336\n",
    "\n",
    "# Retrieved result #449197422118225604\n",
    "# distance = 0.01587301678955555\n",
    "\n",
    "# Retrieved result #449197422118225631\n",
    "# distance = 0.015384615398943424\n",
    "\n",
    "# Retrieved result #449197422118225568\n",
    "\n",
    "# Rerank = False\n",
    "# Retrieved result #449197422118225569\n",
    "# distance = 0.39207690954208374\n",
    "\n",
    "# Retrieved result #449197422118225570\n",
    "# distance = 0.36890196800231934\n",
    "\n",
    "# Retrieved result #449197422118225710\n",
    "# distance = 0.25794607400894165\n",
    "\n",
    "# Retrieved result #449197422118225553\n",
    "# distance = 0.25781089067459106\n",
    "\n",
    "# Retrieved result #449197422118225771\n",
    "# distance = 0.25772351026535034\n",
    "\n",
    "# Retrieved result #449197422118225565\n",
    "# distance = 0.2537841796875\n",
    "\n",
    "# Retrieved result #449197422118225604\n",
    "# distance = 0.2157345414161682\n",
    "\n",
    "# Retrieved result #449197422118225575\n",
    "# distance = 0.2065277099609375\n",
    "\n",
    "# Retrieved result #449197422118225562\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd6060ce",
   "metadata": {},
   "source": [
    "## Use an LLM to Generate a chat response to the user's question using the Retrieved Context.\n",
    "\n",
    "Many different generative LLMs exist these days.  Check out the lmsys [leaderboard](https://chat.lmsys.org/?leaderboard).\n",
    "\n",
    "In this notebook, we'll try these LLMs:\n",
    "- The newly released open-source Llama 3 from Meta.\n",
    "- The cheapest, paid model from Anthropic Claude3 Haiku.\n",
    "- The standard in its price cateogory, gpt-3.5-turbo, from Openai."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "eb4c323f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Length long text to summarize: 1017\n"
     ]
    }
   ],
   "source": [
    "# STEP 8. LLM-GENERATED ANSWER TO THE QUESTION, GROUNDED BY RETRIEVED CONTEXT.\n",
    "\n",
    "# Separate all the context together by space.\n",
    "# Lance Martin, LangChain, says put best contexts at end.\n",
    "contexts_combined = ' '.join(reversed(contexts))\n",
    "# Separate all the sources together by comma.\n",
    "source_combined = ' '.join(reversed(sources))\n",
    "print(f\"Length long text to summarize: {len(contexts_combined)}\")\n",
    "\n",
    "# Define temperature for the LLM and random seed.\n",
    "TEMPERATURE = 0.1\n",
    "RANDOM_SEED = 415\n",
    "FREQUENCY_PENALTY = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11fb35aa",
   "metadata": {},
   "source": [
    "# Try Llama3 with Ollama to generate a human-like chat response to the user's question\n",
    "\n",
    "Follow the instructions to install ollama and pull a model.<br>\n",
    "https://github.com/ollama/ollama\n",
    "\n",
    "View details about which models are supported by ollama. <br>\n",
    "https://ollama.com/library/llama3\n",
    "\n",
    "That page says `ollama run llama3` will by default pull the latest \"instruct\" model, which is fine-tuned for chat/dialogue use cases.\n",
    "\n",
    "The other kind of llama3 models are \"pre-trained\" base model. <br>\n",
    "Example: ollama run llama3:text ollama run llama3:70b-text\n",
    "\n",
    "**Format** `gguf` means the model runs on CPU.  gg = \"Georgi Gerganov\", creator of the C library model format ggml, which was recently changed to gguf.\n",
    "\n",
    "**Quantization** (think of it like vector compaction) can lead to higher throughput at the expense of lower accuracy.  For the curious, quantization meanings can be found on: <br>\n",
    "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML/tree/main.  \n",
    "\n",
    "Below just listing the main quantization types.\n",
    "- **q4_0**: Original quant method, 4-bit.\n",
    "- **q4_k_m**: Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K\n",
    "- **q5_0**: Higher accuracy, higher resource usage and slower inference.\n",
    "- **q5_k_m**: Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K\n",
    "- **q 6_k**: Uses Q8_K for all tensors\n",
    "- **q8_0**: Almost indistinguishable from float16. High resource use and slow. Not recommended for most users."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "0edc67e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MODEL:llama3:latest, FORMAT:gguf, PARAMETER_SIZE:8B, QUANTIZATION_LEVEL:Q4_0, \n",
      "\n"
     ]
    }
   ],
   "source": [
    "# !python -m pip install ollama\n",
    "import ollama\n",
    "\n",
    "# Verify details which model you are running.\n",
    "ollama_llama3 = ollama.list()['models'][0]\n",
    "\n",
    "# Print the model details.\n",
    "keys = ['format', 'parameter_size', 'quantization_level']\n",
    "print(f\"MODEL:{ollama.list()['models'][0]['name']}\", end=\", \")\n",
    "for key in keys:\n",
    "    print(f\"{str.upper(key)}:{ollama.list()['models'][0]['details'].get(key, 'Key not found in dictionary')}\", end=\", \")\n",
    "print(end=\"\\n\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "1c900282",
   "metadata": {},
   "outputs": [],
   "source": [
    "SYSTEM_PROMPT = f\"\"\"Given the provided Context, your task is to \n",
    "understand the content and accurately answer the question based \n",
    "on the information available in the context.  \n",
    "Provide a complete, clear, concise, relevant response in fewer\n",
    "than 4 sentences and cite the unique Sources.\n",
    "Answer: The answer to the question.\n",
    "Sources: {source_combined}\n",
    "Context: {contexts_combined}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "76042c9a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('According to the provided context, the parameters for HNSW (Hierarchical '\n",
      " 'Navigable Small World Graph) are:  * `M`: Maximum degree of nodes on each '\n",
      " 'layer of the graph. This value can improve recall rate at the cost of '\n",
      " 'increased search time. \\t+ Range: (2, 2048) * `ef` (or `efConstruction` when '\n",
      " 'building the index): Search range parameter. \\t+ Range: [1, 65535]  These '\n",
      " 'parameters allow you to control the trade-off between recall rate and search '\n",
      " 'time when using HNSW for indexing and searching.')\n"
     ]
    }
   ],
   "source": [
    "# Send the question to llama 3 chat.\n",
    "response = ollama.chat(\n",
    "    messages=[\n",
    "        {\"role\": \"system\", \"content\": SYSTEM_PROMPT,},\n",
    "        {\"role\": \"user\", \"content\": f\"question: {SAMPLE_QUESTION}\",}\n",
    "    ],\n",
    "    model='llama3',\n",
    "    options={\"temperature\": TEMPERATURE, \"seed\": RANDOM_SEED,}\n",
    ")\n",
    "pprint.pprint(response['message']['content'].replace('\\n', ' '))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "ad0ec552",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ('According to the provided context, the parameters for HNSW (Hierarchical '\n",
    "#  'Navigable Small World Graph) are:  * `M`: Maximum degree of nodes on each '\n",
    "#  'layer of the graph. This value can improve recall rate at the cost of '\n",
    "#  'increased search time. \\t+ Range: (2, 2048) * `ef` (or `efConstruction` when '\n",
    "#  'building the index): Search range parameter. \\t+ Range: [1, 65535]  These '\n",
    "#  'parameters allow you to control the trade-off between recall rate and search '\n",
    "#  'time when using HNSW for indexing and searching.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fd2b2dd",
   "metadata": {},
   "source": [
    "## Use Anthropic to generate a human-like chat response to the user's question \n",
    "\n",
    "We've practiced retrieval for free on our own data using open-source LLMs.  <br>\n",
    "\n",
    "Now let's make a call to the paid Claude3. [List of models](https://docs.anthropic.com/claude/docs/models-overview)\n",
    "- Opus - most expensive\n",
    "- Sonnet\n",
    "- Haiku - least expensive!\n",
    "\n",
    "Prompt engineering tutorials\n",
    "- [Interactive](https://docs.google.com/spreadsheets/d/19jzLgRruG9kjUQNKtCg1ZjdD6l6weA6qRXG5zLIAhC8/edit#gid=150872633)\n",
    "- [Static](https://docs.google.com/spreadsheets/d/1jIxjzUWG-6xBVIa2ay6yDpLyeuOh_hR_ZB75a47KX_E/edit#gid=869808629)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "edf66e04",
   "metadata": {},
   "outputs": [],
   "source": [
    "SYSTEM_PROMPT = f\"\"\"Use the Context below to answer the user's question. \n",
    "Be clear, factual, complete, concise.\n",
    "If the answer is not in the Context, say \"I don't know\". \n",
    "Otherwise answer with fewer than 4 sentences and cite the unique sources.\n",
    "Context: {contexts_combined}\n",
    "Sources: {source_combined}\n",
    "\n",
    "Answer with 2 parts: the answer and the source citations.\n",
    "Answer: The answer to the question.\n",
    "Sources: unique url sources\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "c87b8428",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # !python -m pip install anthropic\n",
    "# import anthropic\n",
    "\n",
    "# ANTHROPIC_API_KEY=os.environ.get(\"ANTHROPIC_API_KEY\")\n",
    "\n",
    "# # # Model names\n",
    "# # claude-3-opus-20240229\n",
    "# # claude-3-sonnet-20240229\n",
    "# # claude-3-haiku-20240307\n",
    "# CLAUDE_MODEL = \"claude-3-haiku-20240307\"\n",
    "# print(f\"Model: {CLAUDE_MODEL}\")\n",
    "# print()\n",
    "\n",
    "# client = anthropic.Anthropic(\n",
    "#     # defaults to os.environ.get(\"ANTHROPIC_API_KEY\")\n",
    "#     api_key=ANTHROPIC_API_KEY,\n",
    "# )\n",
    "\n",
    "# # Print the question and answer along with grounding sources and citations.\n",
    "# print(f\"Question: {SAMPLE_QUESTION}\")\n",
    "\n",
    "# # CAREFUL!! THIS COSTS MONEY!!\n",
    "# message = client.messages.create(\n",
    "#     model=CLAUDE_MODEL,\n",
    "#     max_tokens=1000,\n",
    "#     temperature=0.0,\n",
    "#     system=SYSTEM_PROMPT,\n",
    "#     messages=[\n",
    "#         {\"role\": \"user\", \"content\": SAMPLE_QUESTION}\n",
    "#     ]\n",
    "# )\n",
    "# print(\"Answer:\")\n",
    "# pprint.pprint(message.content[0].text.replace('\\n', ' '))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "5c1c9758",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # model=\"claude-3-haiku-20240307\"\n",
    "\n",
    "# # Question: What do the parameters for HNSW mean?\n",
    "# Answer:\n",
    "# ('The parameters for HNSW (Hierarchical Navigable Small World Graph) are:  1. '\n",
    "#  'M: This is the maximum degree of the nodes in the graph. It controls the '\n",
    "#  'sparsity of the upper layers and the density of the lower layers. The range '\n",
    "#  'for M is (2, 2048).  2. efConstruction: This parameter specifies the search '\n",
    "#  'range when building the index. It affects the recall rate and search time - '\n",
    "#  'a higher value can improve recall at the cost of increased search time.  3. '\n",
    "#  'ef: This parameter specifies the search range when searching for targets. '\n",
    "#  'Similar to efConstruction, a higher value can improve recall but increase '\n",
    "#  'search time.  Sources: [1] https://milvus.io/docs/index.md')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "17f138a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # model=\"claude-3-sonnet-20240229\"\n",
    "\n",
    "# # Question: What do the parameters for HNSW mean?\n",
    "# # Answer:\n",
    "# ('The parameters M and ef/efConstruction control the behavior of the HNSW '\n",
    "#  '(Hierarchical Navigable Small World) algorithm used for indexing and '\n",
    "#  'searching.  M specifies the maximum number of connections (edges) that each '\n",
    "#  'node in the HNSW graph can have. A higher M value allows more connections, '\n",
    "#  'which can improve recall rate (finding more relevant results) but increases '\n",
    "#  'search time.  ef and efConstruction determine how many nodes in each layer '\n",
    "#  'of the HNSW graph should be explored during searching and index construction '\n",
    "#  'respectively. Higher values increase the search range and can improve '\n",
    "#  'accuracy, but also increase computation time.  Sources: [1] '\n",
    "#  'https://milvus.io/docs/index.md [2] https://milvus.io/docs/index.md')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "704d7900",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # model=\"claude-3-opus-20240229\"\n",
    "\n",
    "# # Question: What do the parameters for HNSW mean?\n",
    "# # Answer:\n",
    "# ('According to the context, the HNSW algorithm has two key parameters:  1. M: '\n",
    "#  'The maximum degree of the node on each layer of the graph, which can be set '\n",
    "#  'between 2 and 2048. [1]  2. efConstruction (when building index) or ef (when '\n",
    "#  'searching targets): These parameters specify the search range to improve '\n",
    "#  'performance. [1]  Grounding sources: [1] https://milvus.io/docs/index.md')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8193430b",
   "metadata": {},
   "source": [
    "<div>\n",
    "<img src=\"../../../images/anthropic_claude3.png\" width=\"80%\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1facf70",
   "metadata": {},
   "source": [
    "## Try MistralAI's Mixtral 8x22B-Instruct-v0.1 to generate a human-like chat response to the user's question \n",
    "\n",
    "This time ollama's version requires 48GB RAM. If you have big enough compute, run the command:\n",
    "> ollama run mixtral\n",
    "\n",
    "Since my laptop is a M2 with only 16GB RAM, I decided to **run Mixtral using Anyscale Endpoints**. Instructions to install. <br>\n",
    "> https://github.com/simonw/llm-anyscale-endpoints\n",
    "\n",
    "To get back to **Anyscale Endpoints** anytime, open the playground.<br>\n",
    "https://console.anyscale.com/v2/playground"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "f9d4380a",
   "metadata": {},
   "outputs": [],
   "source": [
    "SYSTEM_PROMPT = f\"\"\"Use the Context below to answer the user's question. \n",
    "Be complete, clear, concise, relevant.\n",
    "If the answer is not in the Context, say \"I don't know\". \n",
    "Otherwise answer with fewer than 4 sentences \n",
    "and cite only unique grounding sources.\n",
    "Answer: The answer to the question.\n",
    "Non-unique grounding sources: {source_combined}\n",
    "Context: {contexts_combined}\n",
    "\"\"\"\n",
    "# print(SYSTEM_PROMPT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "c302f41f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SAMPLE_QUESTION"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "cc124a50",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # model=\"Mixtral 8x22B-Instruct-v0.1\"\n",
    "\n",
    "# # Question: What do the parameters for HNSW mean?\n",
    "# # Answer:\n",
    "# The parameters for HNSW (Hierarchical Navigable Small World Graph) include M, which represents the \n",
    "# maximum degree of the node in the HNSW graph, ranging from 2 to 2048. A higher value of M can \n",
    "# improve recall rate at the cost of increased search time. Additionally, efConstruction (used \n",
    "# during index building) and ef (used during target search) parameters are used to specify a search range.\n",
    "# Reference(s):\n",
    "# https://milvus.io/docs/index.md"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecc532a8",
   "metadata": {},
   "source": [
    "<div>\n",
    "<img src=\"../../../images/mistral_mixtral.png\" width=\"80%\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e172726",
   "metadata": {},
   "source": [
    "## Try OpenAI to generate a human-like chat response to the user's question \n",
    "\n",
    "We've practiced retrieval for free on our own data using open-source LLMs.  <br>\n",
    "\n",
    "Now let's make a call to the paid OpenAI GPT.\n",
    "\n",
    "💡 Note: For use cases that need to always be factually grounded, use very low temperature values while more creative tasks can benefit from higher temperatures."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "426d87d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "SYSTEM_PROMPT = f\"\"\"Use the Context below to answer the user's question. \n",
    "Be complete, clear, concise, relevant.\n",
    "If the answer is not in the Context, say \"I don't know\". \n",
    "Otherwise answer with fewer than 4 sentences and cite the grounding sources.\n",
    "Answer: The answer to the question.\n",
    "Grounding sources: {source_combined}\n",
    "Context: {contexts_combined}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "76a62feb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Question: What do the parameters for HNSW mean?\n",
      "('Answer: The parameters for HNSW are as follows:\\n'\n",
      " '- M: Maximum degree of the node (2, 2048), which limits the maximum '\n",
      " 'connections a node can have in each layer.\\n'\n",
      " '- efConstruction: Used during index building to specify a search range, '\n",
      " 'improving recall rate at the cost of increased search time.\\n'\n",
      " '- ef: Used when searching targets to specify a search range. It determines '\n",
      " 'how many nodes are visited in each layer during the search process.')\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# CAREFUL!! THIS COSTS MONEY!!\n",
    "import openai, pprint\n",
    "from openai import OpenAI\n",
    "\n",
    "# 1. Define the generation llm model to use.\n",
    "# https://openai.com/blog/new-embedding-models-and-api-updates\n",
    "# Customers using the pinned gpt-3.5-turbo model alias will be automatically upgraded to gpt-3.5-turbo-0125 two weeks after this model launches.\n",
    "LLM_NAME = \"gpt-3.5-turbo\"\n",
    "\n",
    "# 2. Get your API key: https://platform.openai.com/api-keys\n",
    "# 3. Save your api key in env variable.\n",
    "# https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety\n",
    "openai_client = OpenAI(\n",
    "    # This is the default and can be omitted\n",
    "    api_key=os.environ.get(\"OPENAI_API_KEY\"),\n",
    ")\n",
    "\n",
    "# 4. Generate response using the OpenAI API.\n",
    "response = openai_client.chat.completions.create(\n",
    "    messages=[\n",
    "        {\"role\": \"system\", \"content\": SYSTEM_PROMPT,},\n",
    "        {\"role\": \"user\", \"content\": f\"question: {SAMPLE_QUESTION}\",}\n",
    "    ],\n",
    "    model=LLM_NAME,\n",
    "    temperature=TEMPERATURE,\n",
    "    seed=RANDOM_SEED,\n",
    "    frequency_penalty=2,\n",
    ")\n",
    "\n",
    "# Print the question and answer along with grounding sources and citations.\n",
    "print(f\"Question: {SAMPLE_QUESTION}\")\n",
    "\n",
    "# 5. Print all answers in the response.\n",
    "for i, choice in enumerate(response.choices, 1):\n",
    "    pprint.pprint(f\"Answer: {choice.message.content}\")\n",
    "    print(\"\\n\")\n",
    "\n",
    "# Question1: What do the parameters for HNSW mean?\n",
    "# Answer:  Looks perfect!\n",
    "# Best answer:  M: maximum degree of nodes in a layer of the graph. \n",
    "# efConstruction: number of nearest neighbors to consider when connecting nodes in the graph.\n",
    "# ef: number of nearest neighbors to consider when searching for similar vectors. \n",
    "\n",
    "# Question2: What are good default values for HNSW parameters with 25K vectors dim 1024?\n",
    "# Answer: M=16, efConstruction=500, and ef=64\n",
    "# Best answer:  M=16, efConstruction=32, ef=32\n",
    "\n",
    "# Question3: what is the default distance metric used in AUTOINDEX in Milvus?\n",
    "# Answer: L2 \n",
    "# Trick answer:  IP inner product, not yet updated in documentation still says L2.\n",
    "\n",
    "# Question4: What does nlist mean in ivf_flat?\n",
    "# 'Answer: In IVF_FLAT, nlist refers to the number of cluster units that divide '\n",
    "#  'a vector space. When using the default value of 16384 for nlist in Milvus, '\n",
    "#  \"distances between the target vector and all 16384 clusters' centers are \"\n",
    "#  'compared to find the nearest clusters for further comparison with vectors '\n",
    "#  'within those selected clusters. This parameter influences how clustering is '\n",
    "#  'performed and affects search efficiency in Milvus.\\n'\n",
    "#  'Sources: https://milvus.io/docs/index.md')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "9af34809",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model=\"gpt-3.5-turbo\"\n",
    "\n",
    "# Question: What do the parameters for HNSW mean?\n",
    "# ('Answer: The parameters for HNSW are M, which is the maximum degree of nodes '\n",
    "#  'on each layer of the graph (ranging from 2 to 2048), and efConstruction '\n",
    "#  '(used during index building) or ef (used during target searching) to specify '\n",
    "#  'a search range. The parameter M can improve recall rate at the expense of '\n",
    "#  'increased search time, while ef controls how exhaustive the search will be '\n",
    "#  'within a certain distance. These parameters help optimize performance in '\n",
    "#  'navigating through multi-layer structures efficiently.\\n'\n",
    "#  'Grounding sources: https://milvus.io/docs/index.md')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb8aeeba",
   "metadata": {},
   "source": [
    "## Use Ragas to evaluate RAG pipeline\n",
    "\n",
    "Ragas is an open source project for evaluating RAG components.  [Paper](https://arxiv.org/abs/2309.15217), [Code](https://docs.ragas.io/en/stable/getstarted/index.html), [Docs](https://docs.ragas.io/en/stable/getstarted/index.html), [Intro blog](https://medium.com/towards-data-science/rag-evaluation-using-ragas-4645a4c6c477).\n",
    "\n",
    "<div>\n",
    "<img src=\"../../../images/ragas_eval_image.png\" width=\"80%\"/>\n",
    "</div>\n",
    "\n",
    "**Please note that RAGAS can use a large amount of OpenAI api token consumption.** <br> \n",
    "\n",
    "Read through this notebook carefully and pay attention to the number of questions and metrics you want to evaluate.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "e1097990",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Question</th>\n",
       "      <th>ground_truth_answer</th>\n",
       "      <th>Sources</th>\n",
       "      <th>Custom_RAG_context</th>\n",
       "      <th>Simple</th>\n",
       "      <th>Custom_RAG_answer</th>\n",
       "      <th>llama3_answer</th>\n",
       "      <th>anthropic_claud3_haiku_answer</th>\n",
       "      <th>mixtral_8x22b_instruct</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>What do the parameters for HNSW mean?</td>\n",
       "      <td># M: maximum degree, or number of connections ...</td>\n",
       "      <td>https://milvus.io/docs/index.md</td>\n",
       "      <td>HNSW (Hierarchical Navigable Small World Graph...</td>\n",
       "      <td>HNSW (Hierarchical Navigable Small World Graph...</td>\n",
       "      <td>The parameters for HNSW are as follows:\\n- M: ...</td>\n",
       "      <td>The parameters for HNSW (Hierarchical Navigabl...</td>\n",
       "      <td>The parameters for HNSW (Hierarchical Navigabl...</td>\n",
       "      <td>The parameter M in HNSW represents the maximum...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>What are good default values for HNSW paramete...</td>\n",
       "      <td>M=16, efConstruction=32, ef=32</td>\n",
       "      <td>https://milvus.io/docs/index.md, https://milvu...</td>\n",
       "      <td>parameters vary with Milvus distribution. Sele...</td>\n",
       "      <td>nbits [Optional] Number of bits in which each ...</td>\n",
       "      <td>M=16, efConstruction=500, and ef=64</td>\n",
       "      <td>HNSW (Hierarchical Navigable Small World) inde...</td>\n",
       "      <td>I don't know. The context provided does not co...</td>\n",
       "      <td>Based on the provided Context, good default va...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>What does nlist vs nprobe mean in ivf_flat?</td>\n",
       "      <td># nlist:  controls how the vector data is part...</td>\n",
       "      <td>https://milvus.io/docs/index.md</td>\n",
       "      <td>FAQ  What is the difference between FLAT index...</td>\n",
       "      <td>index? IVF_FLAT index divides a vector space i...</td>\n",
       "      <td>In IVF_FLAT, nlist refers to the number of clu...</td>\n",
       "      <td>`nlist` refers to the number of cluster units ...</td>\n",
       "      <td>The nlist parameter in the IVF_FLAT index in M...</td>\n",
       "      <td>nlist, which stands for \"number of cluster uni...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>What is the default AUTOINDEX distance metric ...</td>\n",
       "      <td>Trick answer:  IP inner product, not yet updat...</td>\n",
       "      <td>https://milvus.io/docs/index.md</td>\n",
       "      <td>FAQs  API reference  Similarity Metrics  In Mi...</td>\n",
       "      <td>FAQs  API reference  Similarity Metrics  In Mi...</td>\n",
       "      <td>The default distance metric for the AUTOINDEX ...</td>\n",
       "      <td>According to the provided context, the answer ...</td>\n",
       "      <td>The default AUTOINDEX distance metric in Milvu...</td>\n",
       "      <td>The default distance metric for AUTOINDEX in M...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                            Question  \\\n",
       "0              What do the parameters for HNSW mean?   \n",
       "1  What are good default values for HNSW paramete...   \n",
       "2        What does nlist vs nprobe mean in ivf_flat?   \n",
       "3  What is the default AUTOINDEX distance metric ...   \n",
       "\n",
       "                                 ground_truth_answer  \\\n",
       "0  # M: maximum degree, or number of connections ...   \n",
       "1                     M=16, efConstruction=32, ef=32   \n",
       "2  # nlist:  controls how the vector data is part...   \n",
       "3  Trick answer:  IP inner product, not yet updat...   \n",
       "\n",
       "                                             Sources  \\\n",
       "0                    https://milvus.io/docs/index.md   \n",
       "1  https://milvus.io/docs/index.md, https://milvu...   \n",
       "2                    https://milvus.io/docs/index.md   \n",
       "3                    https://milvus.io/docs/index.md   \n",
       "\n",
       "                                  Custom_RAG_context  \\\n",
       "0  HNSW (Hierarchical Navigable Small World Graph...   \n",
       "1  parameters vary with Milvus distribution. Sele...   \n",
       "2  FAQ  What is the difference between FLAT index...   \n",
       "3  FAQs  API reference  Similarity Metrics  In Mi...   \n",
       "\n",
       "                                              Simple  \\\n",
       "0  HNSW (Hierarchical Navigable Small World Graph...   \n",
       "1  nbits [Optional] Number of bits in which each ...   \n",
       "2  index? IVF_FLAT index divides a vector space i...   \n",
       "3  FAQs  API reference  Similarity Metrics  In Mi...   \n",
       "\n",
       "                                   Custom_RAG_answer  \\\n",
       "0  The parameters for HNSW are as follows:\\n- M: ...   \n",
       "1                M=16, efConstruction=500, and ef=64   \n",
       "2  In IVF_FLAT, nlist refers to the number of clu...   \n",
       "3  The default distance metric for the AUTOINDEX ...   \n",
       "\n",
       "                                       llama3_answer  \\\n",
       "0  The parameters for HNSW (Hierarchical Navigabl...   \n",
       "1  HNSW (Hierarchical Navigable Small World) inde...   \n",
       "2  `nlist` refers to the number of cluster units ...   \n",
       "3  According to the provided context, the answer ...   \n",
       "\n",
       "                       anthropic_claud3_haiku_answer  \\\n",
       "0  The parameters for HNSW (Hierarchical Navigabl...   \n",
       "1  I don't know. The context provided does not co...   \n",
       "2  The nlist parameter in the IVF_FLAT index in M...   \n",
       "3  The default AUTOINDEX distance metric in Milvu...   \n",
       "\n",
       "                              mixtral_8x22b_instruct  \n",
       "0  The parameter M in HNSW represents the maximum...  \n",
       "1  Based on the provided Context, good default va...  \n",
       "2  nlist, which stands for \"number of cluster uni...  \n",
       "3  The default distance metric for AUTOINDEX in M...  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os, sys\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import ragas, datasets\n",
    "from langchain_community.embeddings import HuggingFaceEmbeddings\n",
    "from ragas.embeddings import LangchainEmbeddingsWrapper\n",
    "\n",
    "# Import custom functions for evaluation.\n",
    "sys.path.append(\"../../Integration\")  \n",
    "import eval_ragas as _eval_ragas\n",
    "\n",
    "# Import the evaluation metrics.\n",
    "from ragas.metrics import (\n",
    "    context_recall, \n",
    "    context_precision, \n",
    "    faithfulness, \n",
    "    answer_relevancy, \n",
    "    answer_similarity,\n",
    "    answer_correctness\n",
    "    )\n",
    "\n",
    "# Get the current working directory.\n",
    "cwd = os.getcwd()\n",
    "relative_path = '/../../Evaluation/data/ground_truth_answers.csv'\n",
    "file_path = cwd + relative_path\n",
    "# print(f\"file_path: {file_path}\")\n",
    "\n",
    "# Read ground truth answers from file.\n",
    "eval_df = pd.read_csv(file_path, header=0, skip_blank_lines=True)\n",
    "display(eval_df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "8ae8d2b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "965074da2bb1464b8b7e00ad51764cb2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Evaluating:   0%|          | 0/16 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No statements were generated from the answer.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LLM to evaluate: anthropic_claud3_haiku_answer\n",
      "Using 4 eval questions, Mean Score = 0.6172\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>question</th>\n",
       "      <th>contexts</th>\n",
       "      <th>answer</th>\n",
       "      <th>ground_truth</th>\n",
       "      <th>answer_relevancy</th>\n",
       "      <th>answer_similarity</th>\n",
       "      <th>answer_correctness</th>\n",
       "      <th>faithfulness</th>\n",
       "      <th>avg_answer_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>What do the parameters for HNSW mean?</td>\n",
       "      <td>[HNSW (Hierarchical Navigable Small World Grap...</td>\n",
       "      <td>The parameters for HNSW (Hierarchical Navigabl...</td>\n",
       "      <td># M: maximum degree, or number of connections ...</td>\n",
       "      <td>0.749758</td>\n",
       "      <td>0.757881</td>\n",
       "      <td>0.689470</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.732370</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>What are good default values for HNSW paramete...</td>\n",
       "      <td>[parameters vary with Milvus distribution. Sel...</td>\n",
       "      <td>I don't know. The context provided does not co...</td>\n",
       "      <td>M=16, efConstruction=32, ef=32</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.603921</td>\n",
       "      <td>0.150980</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.251634</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>What does nlist vs nprobe mean in ivf_flat?</td>\n",
       "      <td>[FAQ  What is the difference between FLAT inde...</td>\n",
       "      <td>The nlist parameter in the IVF_FLAT index in M...</td>\n",
       "      <td># nlist:  controls how the vector data is part...</td>\n",
       "      <td>0.860034</td>\n",
       "      <td>0.769326</td>\n",
       "      <td>0.692331</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.773897</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>What is the default AUTOINDEX distance metric ...</td>\n",
       "      <td>[FAQs  API reference  Similarity Metrics  In M...</td>\n",
       "      <td>The default AUTOINDEX distance metric in Milvu...</td>\n",
       "      <td>Trick answer:  IP inner product, not yet updat...</td>\n",
       "      <td>0.968280</td>\n",
       "      <td>0.631614</td>\n",
       "      <td>0.532904</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.710933</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                            question  \\\n",
       "0              What do the parameters for HNSW mean?   \n",
       "1  What are good default values for HNSW paramete...   \n",
       "2        What does nlist vs nprobe mean in ivf_flat?   \n",
       "3  What is the default AUTOINDEX distance metric ...   \n",
       "\n",
       "                                            contexts  \\\n",
       "0  [HNSW (Hierarchical Navigable Small World Grap...   \n",
       "1  [parameters vary with Milvus distribution. Sel...   \n",
       "2  [FAQ  What is the difference between FLAT inde...   \n",
       "3  [FAQs  API reference  Similarity Metrics  In M...   \n",
       "\n",
       "                                              answer  \\\n",
       "0  The parameters for HNSW (Hierarchical Navigabl...   \n",
       "1  I don't know. The context provided does not co...   \n",
       "2  The nlist parameter in the IVF_FLAT index in M...   \n",
       "3  The default AUTOINDEX distance metric in Milvu...   \n",
       "\n",
       "                                        ground_truth  answer_relevancy  \\\n",
       "0  # M: maximum degree, or number of connections ...          0.749758   \n",
       "1                     M=16, efConstruction=32, ef=32          0.000000   \n",
       "2  # nlist:  controls how the vector data is part...          0.860034   \n",
       "3  Trick answer:  IP inner product, not yet updat...          0.968280   \n",
       "\n",
       "   answer_similarity  answer_correctness  faithfulness  avg_answer_score  \n",
       "0           0.757881            0.689470           1.0          0.732370  \n",
       "1           0.603921            0.150980           0.0          0.251634  \n",
       "2           0.769326            0.692331           0.4          0.773897  \n",
       "3           0.631614            0.532904           1.0          0.710933  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# TODO: Make this test harness easier to use.\n",
    "##########################################\n",
    "# EVALUATE_WHAT = 'ANSWERS' or 'CONTEXTS'\n",
    "##########################################\n",
    "EVALUATE_WHAT = 'ANSWERS' \n",
    "# EVALUATE_WHAT = 'CONTEXTS' \n",
    "\n",
    "# Possible LLM model choices to evaluate:\n",
    "# 1. openai gpt-3.5-turbo = 'Custom_RAG_answer'\n",
    "# 2. llama3_answer\n",
    "# 3. anthropic_claud3_haiku_answer\n",
    "LLM_TO_EVALUATE = 'Custom_RAG_answer'\n",
    "# LLM_TO_EVALUATE = 'llama3_answer'\n",
    "# LLM_TO_EVALUATE = 'mixtral_8x22b_instruct'\n",
    "# LLM_TO_EVALUATE = 'anthropic_claud3_haiku_answer'\n",
    "\n",
    "# Possible chunking strategies to evaluate:\n",
    "# 1. recursivetextsplitter = 'Simple'\n",
    "# 2. htmlsplitter = 'Custom_RAG_context'\n",
    "# 3. small_big_splitter\n",
    "# CONTEXT_TO_EVALUATE='Simple'\n",
    "CONTEXT_TO_EVALUATE='Custom_RAG_context'\n",
    "\n",
    "if EVALUATE_WHAT == 'ANSWERS':\n",
    "    CONTEXT_TO_EVALUATE='Custom_RAG_context'\n",
    "    eval_metrics=[\n",
    "        answer_relevancy,\n",
    "        answer_similarity,\n",
    "        answer_correctness,\n",
    "        faithfulness,\n",
    "        ]\n",
    "    metrics = ['answer_relevancy', 'answer_similarity', 'answer_correctness', 'faithfulness']\n",
    "\n",
    "elif EVALUATE_WHAT == 'CONTEXTS':\n",
    "    LLM_TO_EVALUATE = 'Custom_RAG_answer'\n",
    "    eval_metrics=[\n",
    "        context_recall, \n",
    "        context_precision,\n",
    "        faithfulness,\n",
    "        ]\n",
    "    metrics = ['context_recall', 'context_precision', 'faithfulness']\n",
    "    \n",
    "# Change the default the llm-as-critic.\n",
    "LLM_NAME = \"gpt-3.5-turbo\"\n",
    "ragas_llm = ragas.llms.llm_factory(model=LLM_NAME)\n",
    "\n",
    "# Change the default embeddings to HuggingFace models.\n",
    "EMB_NAME = \"BAAI/bge-large-en-v1.5\"\n",
    "lc_embeddings = HuggingFaceEmbeddings(model_name=EMB_NAME)\n",
    "ragas_emb = LangchainEmbeddingsWrapper(embeddings=lc_embeddings)\n",
    "\n",
    "# Change each metric.\n",
    "for metric in metrics:\n",
    "    globals()[metric].llm = ragas_llm\n",
    "    globals()[metric].embeddings = ragas_emb\n",
    "\n",
    "# Execute the evaluation.\n",
    "ragas_result, score = _eval_ragas.evaluate_ragas_model(\n",
    "    eval_df, eval_metrics, LLM_TO_EVALUATE, \n",
    "    CONTEXT_TO_EVALUATE, EVALUATE_WHAT)\n",
    "\n",
    "# Display the results.\n",
    "print(f\"Using {eval_df.shape[0]} eval questions, Mean Score = {score}\")\n",
    "display(ragas_result.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c408624",
   "metadata": {},
   "outputs": [],
   "source": [
    "####################################################\n",
    "# Avg Context Precision htmlsplitter score = 0.54 (200% improvement)\n",
    "# Avg Context Precision simple score = 0.17\n",
    "####################################################\n",
    "\n",
    "####################################################\n",
    "# Avg openai gpt-3.5-turbo score = 0.7365 (10% improvement)\n",
    "# Avg mistralai mixtral_8x22b_instruct score = 0.7272 (9% improvement)\n",
    "# Avg llama3 answer score = 0.667 (8% improvement)\n",
    "# Avg anthropic_claud3_haiku_answer score = 0.6172\n",
    "####################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "d0e81e68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop collection\n",
    "utility.drop_collection(COLLECTION_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "c777937e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Author: Christy Bergman\n",
      "\n",
      "Python implementation: CPython\n",
      "Python version       : 3.11.8\n",
      "IPython version      : 8.22.2\n",
      "\n",
      "torch    : 2.2.2\n",
      "datasets : 2.19.0\n",
      "pymilvus : 2.4.1\n",
      "langchain: 0.1.16\n",
      "ollama   : 0.1.8\n",
      "anthropic: 0.25.6\n",
      "openai   : 1.14.3\n",
      "ragas    : 0.1.7\n",
      "\n",
      "conda environment: py311-unum\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Props to Sebastian Raschka for this handy watermark.\n",
    "# !pip install watermark\n",
    "\n",
    "%load_ext watermark\n",
    "%watermark -a 'Christy Bergman' -v -p torch,datasets,pymilvus,langchain,ollama,anthropic,openai,ragas --conda"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
